import math

import torch
import torch.nn as nn

from torch.nn.utils.parametrizations import orthogonal

import geotorch

from lib.lorentz.manifold import CustomLorentz
from lib.geoopt.tensor import ManifoldParameter


class LFC_hyperweight(nn.Module):
    """
        Modified Lorentz fully connected layer of Chen et al. (2022) whose weight is a
        ManifoldParameter (initialized with ``manifold.projx``).

        Code modified from https://github.com/chenweize1998/fully-hyperbolic-nn

        args:
            manifold: Instance of Lorentz manifold
            in_features, out_features: Same as nn.Linear
            bias: Stored for signature compatibility; no bias is applied
            init_scale: Scale parameter for internal normalization (defaults to 2.3)
            learn_scale: If scale parameter should be learnable
            normalize: If internal normalization should be applied
    """

    def __init__(
            self,
            manifold: CustomLorentz,
            in_features,
            out_features,
            bias=False,
            init_scale=None,
            learn_scale=False,
            normalize=False
        ):
        super(LFC_hyperweight, self).__init__()
        self.manifold = manifold
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): stored but never used; forward applies no bias.
        self.bias = bias
        self.normalize = normalize

        self.init_std = 0.02
        # Uniform values in (-init_std, init_std], then projected onto the manifold.
        initial = (-2*self.init_std)* torch.rand((self.out_features, self.in_features)) + self.init_std

        self.w = ManifoldParameter(self.manifold.projx(initial), manifold=self.manifold)

        self.reset_parameters()

        # Scale for internal normalization
        if init_scale is None:
            init_scale = 2.3
        self.scale = nn.Parameter(torch.ones(()) * init_scale, requires_grad=learn_scale)

    def forward(self, x):
        """
            Apply ``x @ w.T`` and rebuild a valid time coordinate.

            Without normalization the time coordinate is attached with
            ``manifold.add_time``. With normalization the space part is rescaled to length
            ``sigmoid(x_0) * exp(scale)`` and the time part is ``sqrt(scale**2 + k + 1e-5)``.
            Outputs whose squared space norm is <= 1e-10 become ``(sqrt(k), 0, ..., 0)``.
        """
        x = torch.nn.functional.linear(x, self.w, bias=None)
        x_space = x.narrow(-1, 1, x.shape[-1] - 1)

        if not self.normalize:
            return self.manifold.add_time(x_space)

        scale = x.narrow(-1, 0, 1).sigmoid() * self.scale.exp()
        square_norm = (x_space * x_space).sum(dim=-1, keepdim=True)

        degenerate = square_norm <= 1e-10

        # Out-of-place fill (instead of in-place indexing into an autograd intermediate).
        # Degenerate rows get norm 1 so the division stays finite; they are zeroed below.
        square_norm = square_norm.masked_fill(degenerate, 1)
        x_space = scale * (x_space / torch.sqrt(square_norm))

        x_time = torch.sqrt(scale**2 + self.manifold.k + 1e-5)
        x_time = x_time.masked_fill(degenerate, self.manifold.k.sqrt())

        x_space = x_space * ~degenerate

        return torch.cat([x_time, x_space], dim=-1)

    def reset_parameters(self):
        # Intentionally a no-op: ``self.w`` is initialized (and projected with
        # ``manifold.projx``) in ``__init__``.
        return


class LorentzTransform(torch.nn.Module):
    """
        Learnable Lorentz transformation: a boost, a spatial rotation, or both.

        A (dim x dim) matrix is built from the learnable parameters on every forward pass
        and multiplied onto the last dimension of the input. When both parts are enabled
        the boost is applied first, then the rotation.

        args:
            dim: Size of the last input dimension (time + space coordinates)
            manifold: Instance of Lorentz manifold; only used when ``stabalize=True``
            mode: "both", "boost" or "rotate"
    """

    _MODES = ("both", "boost", "rotate")
    # Upper bound on the boost speed |v| so that gamma = 1 / sqrt(1 - |v|^2) stays finite.
    _MAX_SPEED = 0.95
    # Lower bound for norms used as divisors: a zero vector then yields 0 instead of NaN.
    _EPS = 1e-12

    def __init__(self, dim, manifold, mode="both"):
        super(LorentzTransform, self).__init__()

        if mode not in self._MODES:
            # Fail fast: with an unknown mode neither part would be built and forward
            # would hit an UnboundLocalError.
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")

        self.dim = dim

        self.boost = self.rotate = False

        if mode in ("both", "boost"):
            # Boost velocity, shape (dim - 1, 1)
            self.v = nn.Parameter(torch.rand((dim - 1, 1)))
            self.boost = True

        if mode in ("both", "rotate"):
            # Space block of the rotation, shape (dim - 1, dim - 1); padded to (dim, dim) in forward
            self.rotation_weight = nn.Parameter(torch.rand((dim - 1, dim - 1)))
            self.rotate = True

        self.eye = nn.Parameter(torch.eye(dim - 1), requires_grad=False)
        self.manifold = manifold

        self.reset_parameters()

    def _boost_matrix(self):
        """Build the (dim x dim) boost matrix from ``self.v`` (speed capped at 0.95)."""
        norm = self.v.norm(2, dim=0, keepdim=False)
        v = self.v * (torch.clamp(norm, max=self._MAX_SPEED) / norm.clamp_min(self._EPS))

        gamma = 1 / torch.sqrt(1 - torch.norm(v) ** 2).reshape(1, -1)  # shape (1, 1)
        v_norm_sq = v.norm(2, dim=0, keepdim=True).clamp_min(self._EPS) ** 2

        top = torch.cat([gamma, -gamma * v.T], dim=1)
        bottom = torch.cat([-gamma * v, self.eye + (gamma - 1) * (v * v.T) / v_norm_sq], dim=1)
        return torch.cat([top, bottom], dim=0)

    def _rotation_matrix(self):
        """Embed ``rotation_weight`` as the space block of a (dim x dim) matrix with 1 at [0, 0]."""
        rotation = torch.nn.functional.pad(self.rotation_weight, (1, 0, 1, 0))
        rotation[..., 0, 0] = 1
        return rotation

    def forward(self, x, stabalize=False):
        """
            Apply the transformation along the last dimension of ``x``.

            args:
                x: Tensor whose last dimension has size ``dim``
                stabalize: If True, map the output with ``manifold.logmap0``, clip the norm
                    of its space part at 10 and map back with ``manifold.expmap0``
                    (argument name kept as-is for caller compatibility)
        """
        matrix = self._boost_matrix() if self.boost else None

        if self.rotate:
            rotation = self._rotation_matrix()
            # Rotation acts after the boost: x -> R @ B @ x
            matrix = rotation if matrix is None else torch.matmul(rotation, matrix)

        output = torch.matmul(matrix, x.transpose(-1, -2)).transpose(-1, -2)

        if stabalize:
            output = self.manifold.logmap0(output)
            space = output[..., 1:]
            norm = space.norm(2, dim=-1, keepdim=True)
            space = space * (torch.clamp(norm, max=10) / norm.clamp_min(self._EPS))
            # NOTE(review): add_time presumably attaches a manifold time coordinate; if
            # expmap0 expects a tangent vector at the origin (time coordinate 0), this
            # should pad a zero instead. Confirm against CustomLorentz.
            output = self.manifold.expmap0(self.manifold.add_time(space))

        return output

    def reset_parameters(self):
        # Intentionally a no-op: ``v`` and ``rotation_weight`` keep their ``torch.rand``
        # initialization from ``__init__``.
        return


class LorentzFullyConnected_transform(nn.Module):
    """
        Variant of the Lorentz fully connected layer of Chen et al. (2022): a Euclidean
        linear map on the space coordinates followed by a learnable LorentzTransform
        (boost + orthogonally parametrized rotation).

        Code modified from https://github.com/chenweize1998/fully-hyperbolic-nn

        args:
            manifold: Instance of Lorentz manifold
            in_features, out_features: Input/output sizes including the time coordinate
            bias: Whether the space linear map has a bias (initialized to 0)
            init_scale, learn_scale: Accepted for signature compatibility; unused
            normalize: Stored on the module but not used by ``forward``
    """

    def __init__(
            self,
            manifold: CustomLorentz,
            in_features,
            out_features,
            bias=False,
            init_scale=None,
            learn_scale=False,
            normalize=False
    ):
        super(LorentzFullyConnected_transform, self).__init__()
        self.manifold = manifold
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.normalize = normalize

        # Acts on the space coordinates only (hence the -1 on both sizes).
        self.weight = nn.Linear(self.in_features - 1, self.out_features - 1, bias=bias)

        self.init_std = 0.02
        self.reset_parameters()

        # NOTE(review): unused in this class; kept so parameters()/state_dict are unchanged.
        self.shape_matrix = nn.Parameter(torch.ones((in_features, out_features)), requires_grad=False)

        # The rotation block of the transform is constrained to be orthogonal.
        self.transform = LorentzTransform(out_features, self.manifold)
        self.transform = orthogonal(self.transform, name="rotation_weight")

    def forward(self, x):
        """Map the space part with ``self.weight`` (only if sizes differ), then apply the Lorentz transform."""
        if self.out_features != self.in_features:
            x_space = self.weight(x[..., 1:])
            x = self.manifold.add_time(x_space)
        # When in_features == out_features, self.weight is skipped and x goes straight to the transform.

        return self.transform(x)

    def reset_parameters(self):
        """Initialize the space linear map uniformly in [-init_std, init_std] and its bias to 0."""
        nn.init.uniform_(self.weight.weight, -self.init_std, self.init_std)

        if self.bias:
            nn.init.constant_(self.weight.bias, 0)


class LorentzBoost(nn.Module):
    r"""
        Lorentz boost mixing the time coordinate x_0 with the last space coordinate x_n.

        The pair (x_0, x_n) is multiplied by
            [cosh(a)  sinh(a)]
            [sinh(a)  cosh(a)]
        and x_1, ..., x_{n-1} are left unchanged. The learnable ``weight`` stores sinh(a)
        directly, so cosh(a) is computed as sqrt(weight**2 + 1).

        args:
            manifold: Instance of Lorentz manifold (stored; not used in forward)
    """

    def __init__(self, manifold):
        super().__init__()
        self.manifold = manifold
        # sinh of the boost rapidity. torch.FloatTensor(1) would leave this uninitialized,
        # so allocate explicitly and initialize in reset_parameters.
        self.weight = nn.Parameter(torch.empty(1, dtype=torch.float32))
        self.reset_parameters()

    def forward(self, x):  # x = [x_0, x_1, ..., x_n]
        last = x.shape[-1] - 1
        x_time = x.narrow(-1, 0, 1)          # x_0
        x_mid = x.narrow(-1, 1, last - 1)    # x_1, ..., x_{n-1} (unchanged)
        x_last = x.narrow(-1, last, 1)       # x_n

        sinh_a = self.weight
        cosh_a = torch.sqrt(sinh_a ** 2 + 1.0)

        new_time = cosh_a * x_time + sinh_a * x_last
        new_last = sinh_a * x_time + cosh_a * x_last

        return torch.cat([new_time, x_mid, new_last], dim=-1)

    def reset_parameters(self):
        # xavier_uniform_ rejects 1-D tensors. Use the bound it would produce for a 1x1
        # weight (fan_in = fan_out = 1) with gain sqrt(2): U(-b, b), b = sqrt(2) * sqrt(6 / 2).
        bound = math.sqrt(2) * math.sqrt(6.0 / 2.0)
        nn.init.uniform_(self.weight, -bound, bound)


class LorentzRotation(nn.Module):
    """
        Keep the time coordinate and map the space coordinates with an orthogonal
        (Cayley-parametrized) linear map.

        When in_features != out_features the space map is only semi-orthogonal and the time
        coordinate is passed through unchanged, so the output need not lie on the manifold.
        ``if_regularize``/``if_projected`` apply ``manifold.regularize``/``manifold.projx``
        afterwards.

        args:
            manifold: Instance of Lorentz manifold (only used by the optional post-processing)
            in_features, out_features: Input/output sizes including the time coordinate
            if_dropout, dropout: Accepted for signature compatibility; currently unused
            if_regularize: Apply ``manifold.regularize`` to the output
            if_projected: Apply ``manifold.projx`` to the output
    """

    def __init__(self,
                 manifold,
                 in_features,
                 out_features,
                 if_dropout=False,
                 dropout=0,
                 if_regularize = False,
                 if_projected = False
                 ):
        super().__init__()
        self.manifold = manifold
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(self.in_features - 1, self.out_features - 1, bias=False)
        self.linear = orthogonal(self.linear, "weight", orthogonal_map="cayley")
        self.reset_parameters()
        self.if_regularize = if_regularize
        self.if_projected = if_projected

    def forward(self, x):
        x_time = x.narrow(-1, 0, 1)
        x_space = x.narrow(-1, 1, x.shape[-1] - 1)

        x = torch.cat([x_time, self.linear(x_space)], dim=-1)

        if self.if_regularize:
            x = self.manifold.regularize(x)

        if self.if_projected:
            x = self.manifold.projx(x)

        return x

    def reset_parameters(self):
        """
            Re-sample the orthogonal space map.

            ``self.linear.weight`` is recomputed by the orthogonal parametrization on every
            access, so initializing it in place would only modify a temporary tensor. Assigning
            a new value goes through the parametrization's ``right_inverse``, which
            orthogonalizes it. No column is zeroed: the map acts on space coordinates only,
            and a zero column would make the matrix rank-deficient before orthogonalization.
        """
        stdv = 1. / math.sqrt(self.out_features)
        with torch.no_grad():
            init = torch.empty_like(self.linear.weight).uniform_(-stdv, stdv)
        self.linear.weight = init


class LorentzProjection(nn.Module):
    """
    Lorentz projection layer: rotation -> boost -> (semi-)orthogonal projection.

    The input goes through a LorentzRotation (in_features -> in_features), then a
    LorentzBoost, then a second LorentzRotation (in_features -> out_features) whose output
    is post-processed with ``manifold.regularize`` and ``manifold.projx``.

    args:
        manifold: Instance of Lorentz manifold
        in_features, out_features: Input/output sizes including the time coordinate
        dropout: Forwarded as the projection's ``dropout`` argument
            (NOTE(review): LorentzRotation does not currently use it)
    """

    def __init__(self, manifold, in_features, out_features, dropout=False):
        super(LorentzProjection, self).__init__()
        self.rotation = LorentzRotation(manifold, in_features, in_features, if_dropout=False, if_regularize=False, if_projected=False)
        self.boost = LorentzBoost(manifold)
        self.projection = LorentzRotation(manifold, in_features, out_features, if_dropout=True, dropout=dropout, if_regularize=True, if_projected=True)

    def forward(self, input):
        xt = self.rotation(input)
        xt = self.boost(xt)
        return self.projection(xt)

